{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eff55242",
   "metadata": {},
   "source": [
    "# {class}`~tvm.relax.training.AppendLoss`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ec2024ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm.testing\n",
    "from tvm import TVMError\n",
    "from tvm.ir.base import assert_structural_equal\n",
    "from tvm.script import relax as R, ir as I\n",
    "from tvm.relax.training import AppendLoss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b0ab714",
   "metadata": {},
   "source": [
    "## 简单测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0a001cd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "@I.ir_module\n",
    "class Before:\n",
    "    @R.function\n",
    "    def main(x: R.Tensor((3, 3), \"float32\"), y: R.Tensor((3, 3), \"float32\")):\n",
    "        with R.dataflow():\n",
    "            gv0 = x + y\n",
    "            R.output(gv0)\n",
    "        return gv0\n",
    "\n",
    "@R.function\n",
    "def loss(arg1: R.Tensor((3, 3), \"float32\")):\n",
    "    with R.dataflow():\n",
    "        gv0 = R.sum(arg1)\n",
    "        R.output(gv0)\n",
    "    return gv0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "15b79ef9",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main_loss</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv0: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, y)\n",
       "            gv0_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(gv0, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv0_1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv0_1\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv0: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, y)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv0)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv0\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "After = AppendLoss(\"main\", loss)(Before)\n",
    "After.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da70930e",
   "metadata": {},
   "source": [
    "## 测试多主干\n",
    "\n",
    "测试带多个主干(backbone)输出的场景"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c155f670",
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main_loss</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv0: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(x, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            gv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(y, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            gv0_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(gv0, gv1)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv0_1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv0_1\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), y: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv0: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(x, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            gv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(y, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv0, gv1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> (gv0, gv1)\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "@I.ir_module\n",
    "class Before:\n",
    "    @R.function\n",
    "    def main(x: R.Tensor((3, 3), \"float32\"), y: R.Tensor((3, 3), \"float32\")):\n",
    "        with R.dataflow():\n",
    "            gv0 = R.sum(x)\n",
    "            gv1 = R.sum(y)\n",
    "            R.output(gv0, gv1)\n",
    "        return gv0, gv1\n",
    "\n",
    "@R.function\n",
    "def loss(arg1: R.Tensor((), \"float32\"), arg2: R.Tensor((), \"float32\")):\n",
    "    with R.dataflow():\n",
    "        gv0 = R.add(arg1, arg2)\n",
    "        R.output(gv0)\n",
    "    return gv0\n",
    "\n",
    "After = AppendLoss(\"main\", loss, 2)(Before)\n",
    "After.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76628b9a",
   "metadata": {},
   "source": [
    "## 测试附加参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ffe8fa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Before:\n",
    "    # 原始主干网络返回三个输出：\n",
    "    # gv0: 输入张量的求和结果（标量）\n",
    "    # gv1: 输入自身的相加结果（3x3矩阵） \n",
    "    # gv2: 原始输入张量（3x3矩阵）\n",
    "    @R.function\n",
    "    def main(x: R.Tensor((3, 3), \"float32\")):\n",
    "        with R.dataflow():\n",
    "            gv0 = R.sum(x)\n",
    "            gv1 = R.add(x, x)  # 矩阵自相加\n",
    "            gv2 = x            # 保留原始输入\n",
    "            R.output(gv0, gv1, gv2)\n",
    "        return gv0, gv1, gv2\n",
    "\n",
    "# 损失函数接收三个参数：\n",
    "# - arg1: 主干第一个输出（标量）\n",
    "# - arg2: 主干第二个输出（矩阵）\n",
    "# - arg3: 额外参数（矩阵）\n",
    "@R.function\n",
    "def loss(arg1, arg2, arg3):\n",
    "    with R.dataflow():\n",
    "        gv0 = R.add(arg2, arg3)  # 矩阵相加（使用额外参数）\n",
    "        gv1 = R.sum(gv0)         # 求和得到损失值\n",
    "        R.output(gv1)\n",
    "    return gv1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d327d363",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main_loss</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), arg3: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv0: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(x, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            gv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, x)\n",
       "            gv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> x\n",
       "            gv0_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(gv1, arg3)\n",
       "            gv1_1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(gv0_1, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv2, gv1_1)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> (gv1_1, gv2)\n",
       "\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(x: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tuple(R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            gv0: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>sum(x, axis<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>, keepdims<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">False</span>)\n",
       "            gv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>add(x, x)\n",
       "            gv2: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">3</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> x\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv0, gv1, gv2)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> (gv0, gv1, gv2)\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "After = AppendLoss(\"main\", loss, 2)(Before)\n",
    "After.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9d508ae",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
